import time
import sys
import datetime
import numpy as np
from tqdm import tqdm
import os.path as op
from .gdal_utils import gdal_reader
from sklearn.model_selection import train_test_split
from .func_utils import (get_dims_from_coordstxt,
                         check_outputs, 
                         check_inputs, 
                         load_depth,
                         load_npz,
                         dump_npz,
                         clip_inds,
                         dump_labels)

import matplotlib.pyplot as plt
from .generate_geo_index import generate_geo_index
import gc

# @check_inputs(inputs=["coords_txt"])  # NOTE(review): input check disabled -- confirm whether intentional
@check_outputs(stage_name="ConvertCoords", override=False)
def convert_coords_npz(**kwargs):
    """Convert an ASCII coordinate table (coords.txt) into a .npz archive.

    Kwargs:
      input_files: dict with key "coords_txt" -> path to coords.txt
      out_files: dict with key "coords_npz" -> path for the output .npz

    Raises:
      ValueError: if the number of data rows in coords.txt does not match
        the spatial dimensions read from its header.
    """
    input_files = kwargs["input_files"]
    out_files = kwargs["out_files"]

    # Read the spatial size of the geotiff grid from the coords.txt header.
    dim_x, dim_y, geo_size = \
            get_dims_from_coordstxt(input_files["coords_txt"])

    # Load coords data: skip the 8 header lines, keep only the
    # `Lat` and `Lon` columns (columns 2 and 3).
    print("Info: load coord txt, it may take some time.")
    coord_data = np.loadtxt(input_files["coords_txt"],
                            dtype="float32",
                            skiprows=8,
                            usecols=[2, 3])

    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this sanity check.
    if coord_data.shape[0] != dim_x * dim_y:
        raise ValueError(
            f"coords.txt has {coord_data.shape[0]} rows, expected "
            f"{dim_x}*{dim_y}={dim_x * dim_y} from the header dimensions")
    # Reshape the flat (N, 2) table into the 2-D spatial grid.
    coord_data = coord_data.reshape(*geo_size, 2)

    dump_npz(out_files["coords_npz"],
             coord_data=coord_data,
             dim_x=dim_x,
             dim_y=dim_y)

@check_inputs(inputs=["coords_npz", "depth_txt"])
@check_outputs(stage_name="GenerateIndex", override=False)
def generate_geo_index_wrapper(**kwargs):
    """Map depth labels onto pixel indices of the coordinate grid.

    For every labelled (lat, lon, depth) point, finds the matching pixel
    in the coordinate grid (within a lat/lon margin), expands it to a
    square patch of indices, clips indices falling outside the image or
    inside the mask, and dumps the result to an .npz file.

    Kwargs:
      input_files: dict with keys "coords_npz" and "depth_txt"
      out_files: dict with key "inds_npz"
      patch_size: half-width of the square patch taken around each match
      mask: mask passed through to clip_inds
    """
    input_files = kwargs["input_files"]
    out_files = kwargs["out_files"]
    patch_size = kwargs["patch_size"]
    masks = kwargs["mask"]

    # Per-pixel (lat, lon) grid produced by the ConvertCoords stage.
    coord_data, dim_x, dim_y = load_npz(input_files["coords_npz"])

    # Average lat/lon resolution along each spatial axis, estimated from
    # the two opposite image edges. `scale` is the hyperparameter that
    # extends the matching margin around a given lat/lon value.
    scale = 5
    res_lat_y = ((coord_data[0, 0, 0] - coord_data[dim_y-1, 0, 0]) + \
           (coord_data[0, dim_x-1, 0] - coord_data[dim_y-1, dim_x-1, 0]))/(dim_y*2)
    res_lon_x = ((coord_data[0, dim_x-1, 1] - coord_data[0, 0, 1]) + \
           (coord_data[dim_y-1, dim_x-1, 1] - coord_data[dim_y-1, 0, 1]))/(dim_x*2)
    lat_margin, lon_margin = res_lat_y*scale, res_lon_x*scale

    # Load labelled depth measurements: coordinates and values.
    label_coords, label_depths = load_depth(input_files["depth_txt"])

    # Index offsets covering a (2*patch_size+1)^2 square around each match.
    patch_inds = np.arange(-patch_size, patch_size + 1, 1).astype("int32")
    patch_inds_x, patch_inds_y = np.meshgrid(patch_inds, patch_inds)
    # Cython module does the heavy index matching. (The former dead
    # pre-initialization of y_inds/x_inds/depths was removed: the values
    # were always overwritten here.)
    y_inds, x_inds, depths = generate_geo_index(label_coords,
                                                coord_data,
                                                patch_inds_x,
                                                patch_inds_y,
                                                label_depths,
                                                lat_margin,
                                                lon_margin)
    # Drop indices outside the image bounds or inside the mask.
    y_inds, x_inds, depths = clip_inds(y_inds, x_inds,
                                       dim_y, dim_x,
                                       depths, masks)

    dump_npz(out_files["inds_npz"],
             y_inds=y_inds,
             x_inds=x_inds,
             depths=depths)

@check_inputs(inputs=["inds_npz", "geotiff"])
@check_outputs(stage_name="ExtractFeat")
def extract_features(**kwargs):
    """Extract per-label feature patches from the geotiff and dump them.

    Reads the pixel indices produced by the GenerateIndex stage, gathers
    the corresponding spectral features from the geotiff, dumps a
    features/labels .npz, and renders a diagnostic PNG of label locations
    over one image band.

    Kwargs:
      input_files: dict with keys "inds_npz" and "geotiff"
      out_files: dict with keys "feat_label_npz" and "label_png"
      mask: mask passed through to dump_labels
      check_band: optional band index used for the diagnostic plot
        (default 90, i.e. the 91st channel of the HSI data -- previously
        a hard-coded magic number)
    """
    input_files = kwargs["input_files"]
    out_files = kwargs["out_files"]
    masks = kwargs["mask"]
    # Band used for the diagnostic label plot; overridable by callers.
    check_band = kwargs.get("check_band", 90)

    y_inds, x_inds, depths = load_npz(input_files['inds_npz'])

    # Nothing to extract if no valid indices survived clipping.
    if len(y_inds) == 0:
        return

    xsize, ysize, n_channels, data_loader = gdal_reader(input_files['geotiff'])
    # Read-only object backed by the on-disk store.
    geotiff_data_ro_hd = data_loader()
    # Materialize in memory for fast fancy indexing.
    print("Info: it will take some to move geotiff data to memory, please ensure your memory size > 10G.")
    geotiff_data_mem = geotiff_data_ro_hd.copy()
    # Gather (n_samples, n_channels) feature rows at the label pixels.
    feats = geotiff_data_mem[:, y_inds, x_inds].swapaxes(0, 1)
    feats = feats.reshape(feats.shape[0], -1)

    dump_npz(out_files["feat_label_npz"],
             feats=feats,
             labels=depths,
             n_channels=n_channels)

    # Release the large in-memory copies before plotting.
    del geotiff_data_mem
    del feats
    gc.collect()

    # Plot label locations over the selected band for visual inspection.
    band_to_check = geotiff_data_ro_hd[check_band, :, :]
    dump_labels(out_files["label_png"],
                band_to_check,
                depths,
                masks,
                y_inds,
                x_inds)

def _append_existing(npz_path, feats, labels):
    """Prepend feats/labels from a previously dumped .npz file, if any.

    Returns the (possibly concatenated) feats and labels arrays.
    """
    if op.exists(npz_path):
        pre_feats, pre_labels, _ = load_npz(npz_path)
        feats = np.concatenate((pre_feats, feats), axis=0)
        labels = np.concatenate((pre_labels, labels), axis=0)
    return feats, labels

@check_inputs(inputs=["feat_label_npz"])
def merge_data(**kwargs):
    """Split features/labels into train/val sets and merge them into
    accumulator .npz files across pipeline runs.

    Kwargs:
      input_files: dict with key "feat_label_npz"
      out_files: dict with keys "train_npz" and "val_npz"
      proportion: fraction of samples assigned to the training split
    """
    input_files = kwargs["input_files"]
    out_files = kwargs["out_files"]
    proportion = kwargs["proportion"]

    # Load this run's features and labels.
    feats, labels, n_channels = load_npz(input_files['feat_label_npz'])
    # Guarantee at least one training sample even for tiny inputs.
    # NOTE(review): proportion == 1.0 makes the val split empty, which
    # train_test_split rejects -- assumed proportion < 1; confirm upstream.
    n_trains = np.clip(int(proportion * feats.shape[0]),
                       1, feats.shape[0])

    feats_train, feats_val, \
    labels_train, labels_val = train_test_split(feats, labels,
                                                train_size=n_trains,
                                                random_state=42)

    # Fix: always dump n_channels read from this run's npz. The previous
    # code reassigned n_channels from a pre-existing train file and let
    # that value leak into the val dump.
    feats_train, labels_train = _append_existing(out_files["train_npz"],
                                                 feats_train, labels_train)
    dump_npz(out_files["train_npz"],
             feats=feats_train,
             labels=labels_train,
             n_channels=n_channels)

    feats_val, labels_val = _append_existing(out_files["val_npz"],
                                             feats_val, labels_val)
    dump_npz(out_files["val_npz"],
             feats=feats_val,
             labels=labels_val,
             n_channels=n_channels)
    
